/******************************************************************************
 *  Copyright (c) 2016, Xilinx, Inc.
 *  All rights reserved.
 *
 *  Redistribution and use in source and binary forms, with or without
 *  modification, are permitted provided that the following conditions are met:
 *
 *  1.  Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *  2.  Redistributions in binary form must reproduce the above copyright
 *      notice, this list of conditions and the following disclaimer in the
 *      documentation and/or other materials provided with the distribution.
 *
 *  3.  Neither the name of the copyright holder nor the names of its
 *      contributors may be used to endorse or promote products derived from
 *      this software without specific prior written permission.
 *
 *  THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *  AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
 *  THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 *  PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
 *  CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
 *  EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
 *  PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
 *  OR BUSINESS INTERRUPTION). HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
 *  WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
 *  OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
 *  ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 *****************************************************************************/
#pragma once
#include "tiny_cnn/layers/layer.h"
#include "tiny_cnn/activations/activation_function.h"
#include "tiny_cnn/util/util.h"
#include <algorithm>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

namespace tiny_cnn {
/// Per-channel thresholding layer for binarized networks (inference only).
///
/// The input is treated as `channels` consecutive runs of `dim` elements:
/// element `j` of channel `ch` lives at index `ch*dim + j`. Each element is
/// mapped to +1 if it is strictly greater than its channel's integer
/// threshold and to -1 otherwise. Channels flagged in invertOutput() have
/// that sign flipped. The layer has no weights or biases, and both
/// back-propagation entry points throw.
class bnn_threshold_layer : public layer<activation::identity> {
public:
    typedef layer<activation::identity> Base;

    /// @param channels        number of channels; each channel has a separate threshold.
    /// @param dim             number of pixels/elements in each channel.
    /// @param binaryParamFile if non-empty, thresholds are loaded from this file
    ///                        with loadFromBinaryFile().
    /// @throws const char* if binaryParamFile is given but cannot be opened or
    ///         does not contain enough threshold entries.
    bnn_threshold_layer(cnn_size_t channels, cnn_size_t dim = 1, std::string binaryParamFile = "")
        : Base(dim*channels, dim*channels, 0, 0), dim_(dim), channels_(channels),
          thresholds_(channels, 0), invertOutput_(channels, false)
    {
      // TODO re-enable parallelization -- need to support worker index in forward prop
      set_parallelize(false);
      if(!binaryParamFile.empty())
        loadFromBinaryFile(binaryParamFile);
    }

    /// Loads one threshold per channel from a raw binary file.
    ///
    /// The file must start with `channels` consecutive 8-byte integers in host
    /// byte order. Entry `i` becomes the threshold of channel `i`, narrowed to
    /// int. Trailing bytes are ignored. thresholds_ is modified only after all
    /// entries have been read successfully.
    ///
    /// Does not set invertOutput, which should not be necessary anyway: bin
    /// weight files are generated by a Python script that flips the weights
    /// when inverted output is needed.
    ///
    /// @throws const char* if the file cannot be opened or holds fewer than
    ///         `channels` entries.
    void loadFromBinaryFile(std::string fileName) {
      // TODO this assumes the binary file always uses 8 bytes per threshold entry
      std::ifstream tf(fileName, std::ios::binary | std::ios::in);
      if(!tf.is_open())
        throw "Could not open file";

      std::vector<int> loaded(channels_, 0);
      for(unsigned int ch = 0; ch < channels_; ch++) {
        std::int64_t entry = 0;
        // A short read means the file is truncated or has fewer entries than
        // channels; fail instead of silently keeping zero/garbage thresholds.
        if(!tf.read(reinterpret_cast<char *>(&entry), sizeof(entry)))
          throw "Could not read thresholds from file";
        loaded[ch] = static_cast<int>(entry);
      }
      std::copy(loaded.begin(), loaded.end(), thresholds_.begin());
    }

    /// Mutable access to the per-channel thresholds (one entry per channel).
    std::vector<int> & thresholds() {
      return thresholds_;
    }

    /// Mutable access to the per-channel output-negation flags.
    std::vector<bool> & invertOutput() {
      return invertOutput_;
    }

    /// Returns in_size_ (the inherited input size).
    size_t connection_size() const override {
        return in_size_;
    }

    /// Returns dim_ (elements per channel).
    size_t fan_in_size() const override {
        return dim_;
    }

    /// Returns dim_ (elements per channel).
    size_t fan_out_size() const override {
        return dim_;
    }

    /// Writes the thresholded input (+1/-1 per element, negated for inverted
    /// channels) into output_[index], then forwards it to the next layer.
    /// @return the next layer's output, or this layer's output if it is last.
    const vec_t& forward_propagation(const vec_t& in, size_t index) override {
        vec_t &out = output_[index];

        for(unsigned int ch = 0; ch < channels_; ch++) {
          // Per-channel parameters are constant across the element loop.
          const int threshold = thresholds_[ch];
          const bool invert = invertOutput_[ch];
          const unsigned int base = ch*dim_;
          for(unsigned int j = 0; j < dim_; j++) {
              const bool above = in[base + j] > threshold;
              // +1 when exactly one of (above, invert) holds, -1 otherwise;
              // equivalent to computing the sign and negating it if inverted.
              out[base + j] = (above != invert) ? +1 : -1;
          }
        }

        return next_ ? next_->forward_propagation(out, index) : out;
    }

    /// Not supported: this layer is inference-only.
    /// @throws const char* always.
    const vec_t& back_propagation(const vec_t& curr_delta, size_t index) override {
        throw "Not yet implemented";
        return curr_delta;
    }

    /// Not supported: this layer is inference-only.
    /// @throws const char* always.
    const vec_t& back_propagation_2nd(const vec_t& current_delta2) override {
        throw "Not yet implemented";
        return current_delta2;
    }

    std::string layer_type() const override { return "bnn_threshold_layer"; }

protected:
    unsigned int dim_;       // elements per channel
    unsigned int channels_;  // number of channels (one threshold each)


    std::vector<int> thresholds_;     // per-channel threshold, constructed with channels_ entries
    std::vector<bool> invertOutput_;  // per-channel negation flag, constructed with channels_ entries
};

} // namespace tiny_cnn
